#!/usr/bin/python
import code,pickle,sys,os,re
import random as pyrandom
from pylab import *
from optparse import OptionParser
import ocrolib
from ocrolib import quant,utils,dbutils

# Command-line interface.  Two positional arguments are expected (the input
# character database and the output cluster database); the options below are
# declared in the order they should appear in --help.
parser = OptionParser("""
usage: %prog [options] chars.db output.db

""")

# Flags: each is stored as None unless present on the command line.
parser.add_option(
    "-D", "--display",
    dest="display",
    action="store_true",
    help="display chars",
)
parser.add_option(
    "-v", "--verbose",
    dest="verbose",
    action="store_true",
    help="verbose output",
)

# Name of the table holding the character rows.
parser.add_option(
    "-t", "--table",
    dest="table",
    default="chars",
    help="table name",
)

# Distance threshold handed to each EpsNet instance.
parser.add_option(
    "-e", "--epsilon",
    dest="epsilon",
    type="float",
    default=0.2,
    help="epsilon",
)

parser.add_option(
    "-o", "--overwrite",
    dest="overwrite",
    action="store_true",
    help="overwrite output if it exists",
)

# Size controls: k-means sample size, k itself, and a cap on processed rows.
parser.add_option(
    "-N", "--nsamples",
    dest="nsamples",
    type="int",
    default=100000,
    help="number of samples for clustering",
)
parser.add_option(
    "-K", "--nbuckets",
    dest="nbuckets",
    type="int",
    default=100,
    help="number of buckets",
)
parser.add_option(
    "-L", "--limit",
    dest="limit",
    type="int",
    default=10000000,
    help="total number of samples",
)


def showgrid(data,r=None,d=6):
    """Preview the leading rows of `data` as a grid of square grayscale images.

    data -- 2-D array; each row is reshaped to (r, r) before being drawn
    r    -- side length of each image; when None it is taken as the integer
            square root of the row length
    d    -- the grid is d by d, so at most d*d rows are shown

    The current figure is cleared first, the shape and value range of `data`
    are printed, and the call ends with ginput(1, timeout=1).
    """
    clf()
    print(" ".join(map(str, ("showgrid", data.shape, amin(data), amax(data)))))
    gray()
    side = int(sqrt(data.shape[1])) if r is None else r
    shown = min(len(data), d*d)
    for cell in range(shown):
        # subplot positions are 1-based
        subplot(d, d, cell+1)
        imshow(data[cell].reshape(side, side))
    ginput(1, timeout=1)

class Hist(dict):
    """Histogram stored as a dict mapping each added value to its count."""
    def add(self,x):
        """Count one more occurrence of `x`."""
        seen = self.get(x)
        self[x] = 1 if seen is None else seen+1
    def cls(self):
        """Return the (value, count) pair that has the largest count.

        max() raises ValueError when the histogram is empty.
        """
        return max(self.items(), key=lambda entry: entry[1])

class EpsNet:
    """Incremental clusterer that opens a new bucket for any far-away vector.

    A vector whose distance to the nearest existing centre (as reported by
    quant.argmindist2) exceeds `eps` starts a new bucket; otherwise it is
    averaged into that nearest centre.

    Attributes:
      eps     -- distance threshold for opening a new bucket
      data    -- 2-D array with one centre per row, or None before the first add
      classes -- one Hist of labels per bucket
      counts  -- number of vectors merged into each bucket
      total   -- number of vectors passed to add()
    """
    def __init__(self,eps=0.05):
        self.total = 0
        self.eps = eps
        self.data = None
        self.classes = []
        self.counts = []
    def _open_bucket(self,cls):
        """Append the label histogram and member count for a brand-new bucket."""
        labels = Hist()
        labels.add(cls)
        self.classes.append(labels)
        self.counts.append(1)
    def add(self,v,cls=None):
        """Insert 1-D vector `v` with label `cls`; return the bucket index used."""
        assert v.ndim==1
        assert amin(v)>-2 and amax(v)<2,"input vectors should be nearly normalized"
        self.total += 1
        if self.data is None:
            # Very first vector: it becomes centre 0, stored as float32.
            self.data = array(v,'f').reshape(1,len(v))
            self._open_bucket(cls)
            return 0
        nearest,dist = quant.argmindist2(v,self.data)
        if dist>self.eps:
            # Too far from every centre: the vector itself becomes a new centre.
            self.data = concatenate([self.data,v.reshape(1,len(v))])
            self._open_bucket(cls)
            return len(self.counts)-1
        # Close enough: fold the vector into the running mean of that centre.
        members = self.counts[nearest]
        self.data[nearest,:] = (members*self.data[nearest,:] + v) * 1.0/(members+1)
        self.classes[nearest].add(cls)
        self.counts[nearest] = members+1
        return nearest
    def cls(self,i):
        """Return the (label, count) pair of the most frequent label in bucket `i`."""
        return self.classes[i].cls()
    def stats(self):
        """Return "<vectors added> <number of buckets>" as a string."""
        return " ".join(map(str,(self.total,len(self.classes))))
    def items(self):
        """Yield one utils.Record per bucket.

        Record fields: image (the centre divided by its maximum, scaled to 255,
        stored as unsigned bytes and reshaped to a square), cls and count (the
        most frequent label and how often that label occurred, not the bucket
        population), and classes (repr of the bucket's label histogram).
        """
        for index in range(len(self.classes)):
            centre = self.data[index,:]
            pixels = array(centre/amax(centre)*255.0,'B')
            side = int(sqrt(len(pixels)))
            pixels.shape = (side,side)
            label,votes = self.cls(index)
            histogram = repr(self.classes[index])
            yield utils.Record(image=pixels,cls=label,count=votes,classes=histogram)

# Parse the command line.  Exactly two positional arguments are required:
# args[0] is the input character database, args[1] the output database.
(options,args) = parser.parse_args()

if len(args)!=2:
    parser.print_help()
    sys.exit(0)

# Refuse to clobber an existing output file unless --overwrite was given.
if os.path.exists(args[1]):
    if not options.overwrite:
        sys.stderr.write("%s: already exists\n"%args[1])
        sys.exit(1)
    # Remove the old output so it is rebuilt from scratch.  The path is
    # args[1]; no variable named `output` exists in this script.
    os.unlink(args[1])

# Feature extractor shared by the sampling pass and the assignment pass.
extractor = ocrolib.BboxFE()
def extract(v):
    """Scale `v` to unit Euclidean norm and pass it through `extractor`.

    v -- numeric array (the callers pass row.float_image())

    Returns whatever extractor.extract() returns for the normalised array.
    The division is done out of place, so the caller's array is left
    untouched and integer-typed input is accepted.

    NOTE(review): an all-zero `v` has zero norm and is divided by zero here;
    there is no guard for that case.
    """
    # `sum` is expected to be the array-aware one star-imported from pylab,
    # so the norm covers every element of a multi-dimensional image.
    norm = sqrt(sum(v**2))
    return extractor.extract(v / norm)

# Switch pylab to interactive mode and the gray colormap before any preview.
ion(); show(); gray()

# NOTE(review): --display and --verbose are parsed, but nothing visible in
# this script consults them; the previews below are always drawn.

# Use the table named by --table (default "chars") for every access to the
# character rows.
table = options.table

db = dbutils.chardb(args[0])
dbutils.table(db,table,image="blob",cluster="integer",cls="integer",count="integer")
ids = list(dbutils.ids(db,table))

# Cap the number of rows processed, then draw the k-means training sample.
ids = ids[:options.limit]
print("%s %s" % ("total", len(ids)))
sample = pyrandom.sample(ids,min(options.nsamples,len(ids)))

# Rows are fetched with dbutils.get, the same accessor the assignment loop
# below uses, so both passes see rows from the same table.
data = []
for char_id in sample:
    row = dbutils.get(db,table,char_id)
    data.append(extract(row.float_image()).ravel())
data = array(data,'f')

showgrid(data)

# Stage 1: coarse k-means partition of the sample into nbuckets buckets.
print("%s %s" % ("sampled", len(data)))
nbuckets = options.nbuckets
means,counts = quant.kmeans(data,k=nbuckets)
print(counts)

showgrid(means)

# Stage 2: one epsilon-net per k-means bucket.
clusterers = [EpsNet(options.epsilon) for i in range(nbuckets)]

total = 0
for char_id in ids:
    row = dbutils.get(db,table,char_id)
    v = extract(row.float_image()).ravel()
    bucket,d = quant.argmindist2(v,means)
    cluster = clusterers[bucket].add(v,row.cls)
    # Combine the bucket-local cluster index and the bucket number into a
    # single id (distinct per pair as long as bucket < nbuckets).
    cluster_id = cluster*nbuckets+bucket
    dbutils.put(db,table,char_id,cluster=int(cluster_id))
    total+=1
    if total%1000==0:
        # Progress report: row count plus stats of the first ten clusterers.
        print("%s %s" % ("#", total))
        print([c.stats() for c in clusterers[:10]])
db.commit()
db.close()

# Write one row per epsilon-net cluster into the output database.
out = dbutils.chardb(args[1])
dbutils.table(out,"clusters",image="blob",cls="text",count="integer",classes="text",cluster="integer")
for bucket,c in enumerate(clusterers):
    for cluster,r in enumerate(c.items()):
        # Store the combined id, not the bucket-local index: this is the same
        # encoding written to each character row's cluster column earlier in
        # this script, so the two tables can be joined on it.
        cluster_id = cluster*nbuckets+bucket
        dbutils.insert(out,"clusters",image=dbutils.image2blob(r.image),cls=r.cls,
                       count=r.count,classes=r.classes,cluster=cluster_id)

out.commit()
out.close()
